# -*- coding: utf-8 -*-
"""
Created on Tue Jul 14 10:59:59 2020

@author: 81283
"""

import numpy as np
from agent.agent import get_next_s
from agent.agent import update_theta
from agent.agent import simple_convert_into_pi_from_theta
from agent.agent import reset


def goal_maze():
    """Run one episode from state 0 until state 8 is reached.

    Each step asks ``get_next_s`` for an ``(action, next_state)`` pair given
    the current state. The episode ends as soon as the next state equals 8.

    Returns:
        A list of ``[state, action]`` pairs in visiting order. The action
        slot of the final entry (state 8) is left as ``np.nan`` because no
        action is taken from it.
    """
    current = 0  # starting position
    trajectory = [[0, np.nan]]  # movement history; action filled in below

    while True:
        action, upcoming = get_next_s(current)

        # Record the action taken from the most recently visited state,
        # then open a new entry for the state that action led to.
        trajectory[-1][1] = action
        trajectory.append([upcoming, np.nan])

        if upcoming == 8:
            return trajectory

        current = upcoming


def once():
    """Repeat episodes and updates until the policy change falls below a threshold.

    Up to 10000 iterations are run. Each iteration plays one episode with
    ``goal_maze``, passes its history to ``update_theta``, and then calls
    ``simple_convert_into_pi_from_theta``. The loop stops early once the sum
    of that call's return value is below ``10**-4``.

    NOTE(review): per the original comments the returned value is the change
    in the policy; the stop test is only meaningful if those changes are
    non-negative (e.g. absolute values) -- confirm in ``agent.agent``.

    Returns:
        A tuple ``(last_index, steps)`` where ``last_index`` is the zero-based
        index of the last iteration executed and ``steps`` is the number of
        moves in the last episode (history length minus one).
    """
    tolerance = 10**-4

    for episode in range(10000):
        trajectory = goal_maze()  # movement history of this episode

        # Update the parameters from the episode, then refresh the policy.
        update_theta(trajectory)
        policy_change = simple_convert_into_pi_from_theta()

        converged = np.sum(policy_change) < tolerance
        if converged:
            print(f"共进行的实验次数为：{episode}")
            break

    steps = len(trajectory) - 1
    return episode, steps
    

if __name__ == "__main__":
    import matplotlib as mpl
    import matplotlib.pyplot as plt
    from collections import defaultdict

    # Font configuration for the axis labels below, which are in Chinese;
    # unicode_minus is disabled so the minus sign is drawn as an ASCII hyphen.
    mpl.rcParams['font.sans-serif'] = ['SimHei']
    mpl.rcParams['axes.unicode_minus'] = False

    plt.grid(True, axis='y', ls=':', color='r', alpha=0.3)

    # Iteration counts of the runs whose final episode took exactly 4 steps.
    four_step_trials = []

    # Number of runs whose final episode did not take 4 steps.
    other_runs = 0

    # Perform 10000 independent runs, resetting the agent after each one.
    for _ in range(10000):
        trials, steps_len = once()

        if steps_len == 4:
            four_step_trials.append(trials)
        else:
            other_runs += 1
            print(f"实验次数：{trials}  探索步数：{steps_len}")

        reset()  # re-initialise parameters and policy

    # Tally how many runs finished with each iteration count, then order the
    # (iteration count, number of runs) pairs by iteration count.
    tally = defaultdict(int)
    for count in four_step_trials:
        tally[count] += 1
    histogram = sorted(tally.items(), key=lambda pair: pair[0])

    positions = list(range(len(histogram)))
    heights = [runs for _, runs in histogram]
    labels = [str(count) for count, _ in histogram]

    plt.bar(positions, heights, align='center', color='b', tick_label=labels,
            alpha=0.6, edgecolor="black")

    plt.xlabel('策略迭代法求解迷宫问题的实验次数')
    plt.ylabel('不同实验次数的个数')

    plt.show()